import random

from dowel import logger, tabular
from torch.optim import Optimizer

import torch


def decision(probability):
    """Flip a biased coin.

    Args:
        probability: chance of the result being True. Because the draw is
            taken from ``[0.0, 1.0)``, any value <= 0 always gives False
            and any value >= 1 always gives True.

    Returns:
        bool: True when a fresh uniform sample falls strictly below
        ``probability``.
    """
    sample = random.random()
    return sample < probability


class PAGEPGOptimizer(Optimizer):
    """Closure-driven optimizer with a probabilistically reset gradient estimator.

    For every parameter the optimizer keeps three tensors in ``self.state``:

    * ``last_point``      - the parameter value before the previous update,
    * ``current_point``   - the parameter value at the start of this step,
    * ``momentum_buffer`` - the running gradient estimator used for the update.

    On each step the estimator is either replaced by the fresh gradient at
    the current point (always on the first step, afterwards with probability
    ``pt``) or incremented by ``grad(current_point) - grad(last_point)``.
    The parameters then move to ``current_point - eta * momentum_buffer``.

    NOTE(review): the class name suggests the PAGE estimator applied to
    policy gradients. Here the reset coin is drawn independently for every
    parameter tensor, whereas PAGE as published draws one coin per
    iteration - confirm this is intended.

    Args:
        params: iterable of parameters or param-group dicts, forwarded to
            ``torch.optim.Optimizer``.
        eta: step size applied to the estimator.
        pt: probability of resetting an estimator to the fresh gradient.
    """

    def __init__(self, params, eta=0.001, pt=0.4):
        self.eta = eta
        self.pt = pt
        self.iteration = -1  # becomes 0 on the first call to step()
        self.sqr_grads_norms = 0  # running sum of squared gradient norms, lagging one step
        self.last_grad_norm = 0  # squared gradient norm of the most recent step
        defaults = dict()
        super(PAGEPGOptimizer, self).__init__(params, defaults)
        for group in self.param_groups:
            for p in group['params']:
                state = self.state[p]
                # Both anchors start at zero; the first step() overwrites
                # current_point with the real parameter values.
                state['last_point'] = torch.zeros_like(p)
                state['current_point'] = torch.zeros_like(p)

    def compute_norm_of_list_var(self, array_):
        """Return the Euclidean norm of a list of tensors viewed as one vector.

        Args:
            array_: list of tensors.

        Returns:
            float: square root of the summed squared 2-norms (0.0 for an
            empty list).
        """
        return sum(t.norm(2).item() ** 2 for t in array_) ** 0.5

    def inner_product_of_list_var(self, array1_, array2_):
        """Return the inner product of two tensor lists viewed as flat vectors.

        Args:
            array1_: list of tensors; its length drives the iteration.
            array2_: list of tensors, indexed in lockstep with ``array1_``.

        Returns:
            A 0-dim tensor holding the sum of element-wise products, or the
            int ``0`` when ``array1_`` is empty.

        Raises:
            IndexError: if ``array2_`` is shorter than ``array1_``.
        """
        total = 0
        for idx, left in enumerate(array1_):
            # Index array2_ explicitly (rather than zip) so a short second
            # list still raises instead of being silently truncated.
            total += torch.sum(left * array2_[idx])
        return total

    def update_model_to_last_point(self, ):
        """Save the live parameters into ``current_point`` and load ``last_point``.

        After this call every parameter holds its ``last_point`` value and
        ``current_point`` holds what the parameter contained before the call.
        """
        for group in self.param_groups:
            with torch.no_grad():
                for p in group['params']:
                    state = self.state[p]
                    state['current_point'].copy_(p)
                    p.copy_(state['last_point'])

    def update_model_to_current_point(self, ):
        """Load the saved ``current_point`` values back into the parameters."""
        for group in self.param_groups:
            with torch.no_grad():
                for p in group['params']:
                    p.copy_(self.state[p]['current_point'])

    def update_parameters(self, ):
        """Apply the estimator step and advance the ``last_point`` anchor.

        Sets each parameter to ``current_point - eta * momentum_buffer`` and
        then records ``current_point`` as the new ``last_point``.

        Raises:
            KeyError: if a parameter has no ``momentum_buffer`` yet (i.e.
                ``step`` has never populated it).
        """
        for group in self.param_groups:
            with torch.no_grad():
                for p in group['params']:
                    state = self.state[p]
                    estimator = state['momentum_buffer']
                    current_point = state['current_point']
                    p.copy_(current_point - self.eta * estimator)
                    state['last_point'].copy_(current_point)

    @staticmethod
    def _grad_or_zeros(p):
        """Return a detached copy of ``p.grad``, or zeros if the closure left it unset."""
        if p.grad is None:
            # A parameter the loss does not depend on has a zero gradient.
            return torch.zeros_like(p)
        return p.grad.detach().clone()

    def step(self, closure=None):
        """Performs a single optimization step.

        The closure is evaluated twice: once as ``closure()`` with the model
        moved to the previous iterate, and once as
        ``closure(current_point=True)`` with the model at the current
        iterate. After each call the parameters' ``.grad`` fields are read.
        This method never clears ``.grad`` itself; presumably the closure
        zeroes stale gradients before its backward pass - verify against
        the caller.

        On the very first step ``last_point`` is all zeros, so the first
        closure call runs at zero parameters; its gradients are not used
        because every estimator is initialised from the fresh gradient.

        Args:
            closure (callable): A closure that reevaluates the model and
                populates gradients. It must accept the keyword argument
                ``current_point``. Required despite the ``None`` default.

        Returns:
            None

        Raises:
            TypeError: if ``closure`` is None. Raised before any state or
                parameter is modified.
        """
        if closure is None:
            raise TypeError(
                "PAGEPGOptimizer.step() requires a closure that recomputes "
                "the gradients")

        self.iteration += 1

        # Flat view over every group so gradients stay paired with their own
        # parameter regardless of how many param groups exist.
        params = [p for group in self.param_groups for p in group['params']]

        # Gradient at the previous iterate. Always restore the live
        # parameters, even if the closure raises, so a failed step cannot
        # leave the model stuck at the old point.
        self.update_model_to_last_point()
        try:
            with torch.enable_grad():
                closure()
            old_grads = [self._grad_or_zeros(p) for p in params]
        finally:
            self.update_model_to_current_point()

        # Gradient at the current iterate.
        with torch.enable_grad():
            closure(current_point=True)
        new_grads = [self._grad_or_zeros(p) for p in params]

        grad_square_norm = 0
        for p, new_grad, old_grad in zip(params, new_grads, old_grads):
            state = self.state[p]
            grad_square_norm += new_grad.norm(2).item() ** 2

            if 'momentum_buffer' not in state or decision(self.pt):
                # Reset: the estimator becomes the fresh gradient.
                state['momentum_buffer'] = new_grad
            else:
                # Recursive update, in place on the stored buffer.
                state['momentum_buffer'].detach().add_(new_grad - old_grad)

        # store square of grad norm (the running sum lags one step behind)
        self.sqr_grads_norms += self.last_grad_norm
        self.last_grad_norm = grad_square_norm
        self.update_parameters()

        with tabular.prefix("PAGEPG" + '/'):
            tabular.record('norm of gradient', grad_square_norm ** (1. / 2))
            logger.log(tabular)
        return None
